import os
import numpy as np
import pandas as pd

# =============================================================================
# Load the runtime data and compute the statistics used for error bars.
# For each linear-proportion folder (i.e. each experimental setting) and each
# method, we compute the median and report the error as half the
# interquartile range (IQR), i.e. (Q75 - Q25) / 2.
#
# References:
# - NumPy documentation: https://numpy.org/doc/stable/
# - Pandas documentation: https://pandas.pydata.org/docs/
# =============================================================================
def load_runtime_stats(dataset_path, linear_dirs, method_list, metric="matrix_times"):
    """
    Load runtime data for the given dataset and compute median and error bar (half IQR)
    for each method across different linear proportion folders.

    Each measurement file is looked up at
    ``<dataset_path>/<linear_dir>/<method>_<metric>.npy``.

    Parameters:
      - dataset_path: Path to the dataset folder.
      - linear_dirs: List of folder names for different linear proportions.
      - method_list: List of method names.
      - metric: Metric to load (default is "matrix_times").

    Returns:
      A dictionary where keys are simplified linear proportion labels and values are
      dictionaries mapping each method to a tuple (median, error), where error is
      (Q75 - Q25) / 2. If the file for a method is missing or contains no values,
      the tuple is (nan, nan) and a warning is printed.
    """
    stats = {}
    for lin_dir in linear_dirs:
        lin_path = os.path.join(dataset_path, lin_dir)
        # Simplify the label (e.g. "linear_proportion_0.5" becomes "0.5")
        label = lin_dir.replace("linear_proportion_", "")
        per_method = {}
        stats[label] = per_method
        for method in method_list:
            file_path = os.path.join(lin_path, f"{method}_{metric}.npy")
            if not os.path.isfile(file_path):
                per_method[method] = (np.nan, np.nan)
                print(f"Warning: {file_path} not found. Skipping this entry.")
                continue

            data = np.load(file_path)
            if data.size == 0:
                # Median/percentiles of an empty array are undefined; record NaN
                # explicitly instead of relying on NumPy's version-dependent
                # behaviour (RuntimeWarning or exception).
                per_method[method] = (np.nan, np.nan)
                print(f"Warning: {file_path} contains no data. Skipping this entry.")
                continue

            median = np.median(data)
            # One call computes both quartiles over the flattened data.
            q25, q75 = np.percentile(data, [25, 75])
            error = (q75 - q25) / 2.0  # half the interquartile range (IQR)
            per_method[method] = (median, error)
    return stats

def _label_sort_key(label):
    """Sort key: numeric labels first (by value), then non-numeric ones lexically."""
    try:
        return (0, float(label), label)
    except ValueError:
        # A folder name that does not reduce to a number must not abort the
        # whole table; place it after the numeric columns.
        return (1, 0.0, label)


def create_runtime_table_transposed(dataset_path, linear_dirs, method_list, metric="matrix_times"):
    """
    Build a display table of runtime statistics with methods as rows and
    linear proportions as columns.

    Parameters:
      - dataset_path: Path to the dataset folder.
      - linear_dirs: List of folder names for different linear proportions.
      - method_list: List of method names (used as the row index).
      - metric: Metric to load (default is "matrix_times").

    Returns:
      A pandas DataFrame whose cells are strings of the form
      "median ± error" (three decimals), or "N/A" when either statistic is NaN.
      Columns are the simplified labels, each appearing once, with numeric
      labels in ascending order followed by any non-numeric labels.
      The row "TDLHD" is renamed to "LoSAM" for display.
    """
    stats = load_runtime_stats(dataset_path, linear_dirs, method_list, metric)

    # De-duplicate labels while preserving first-seen order, so repeated folder
    # names do not create duplicate columns.
    labels = list(dict.fromkeys(
        lin_dir.replace("linear_proportion_", "") for lin_dir in linear_dirs
    ))

    # Create a DataFrame with methods as rows and linear proportions as columns.
    table = pd.DataFrame(
        index=method_list,
        columns=sorted(labels, key=_label_sort_key),
    )

    for label in labels:
        for method in method_list:
            median, error = stats[label][method]
            if np.isnan(median) or np.isnan(error):
                table.loc[method, label] = "N/A"
            else:
                table.loc[method, label] = f"{median:.3f} ± {error:.3f}"

    # Rename the row for "TDLHD" to "LoSAM" for display purposes.
    table = table.rename(index={"TDLHD": "LoSAM"})

    return table

# Example usage:
if __name__ == "__main__":
    import sys

    # `root` was referenced here without ever being defined, which made the
    # script fail with NameError. Take it from the first command-line argument,
    # falling back to the current working directory.
    root = sys.argv[1] if len(sys.argv) > 1 else os.getcwd()
    dataset_path = os.path.join(root, "uniformd25ER1n1000")

    # Full set of available settings and methods, kept for reference.
    all_linear_folders = [
        "linear_proportion_0",
        "linear_proportion_0.25",
        "linear_proportion_0.5",
        "linear_proportion_0.75",
        "linear_proportion_1",
    ]
    all_methods = [
        "TDLHD", "CAPS", "NHTS", "DLiNGAM", "RESIT",
        "SCORE", "NoGAM", "CAM", "RandSort", "VarSort",
    ]

    # Subset actually reported by this run.
    linear_folders = ["linear_proportion_0.5"]
    method_list = ["TDLHD", "DLiNGAM", "RandSort", "VarSort"]

    transposed_table = create_runtime_table_transposed(
        dataset_path, linear_folders, method_list, metric="matrix_times"
    )

    print("Runtime Table (Methods as rows):")
    print(transposed_table.to_string())
